{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "from transformers import ChineseCLIPProcessor, ChineseCLIPModel, Trainer,TrainingArguments\n",
    "import torch\n",
    "from datasets import load_dataset,Dataset\n",
    "from pathlib import Path\n",
    "import pandas as pd \n",
    "import numpy as np \n",
    "from tqdm import tqdm "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name_or_path = \"OFA-Sys/chinese-clip-vit-large-patch14\"\n",
    "model = ChineseCLIPModel.from_pretrained(model_name_or_path)\n",
    "processor = ChineseCLIPProcessor.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [[101, 671, 702, 3749, 6756, 102], [101, 2408, 691, 5546, 4649, 4173, 7890, 976, 2533, 5401, 679, 5401, 6821, 1126, 4157, 2408, 2466, 4173, 7890, 976, 3791, 4908, 6394, 2523, 7028, 6206, 102], [101, 2458, 2552, 7937, 5709, 1920, 4510, 2512, 2137, 3440, 517, 1909, 3821, 4294, 4172, 2630, 518, 6856, 3683, 5439, 3926, 3173, 102], [101, 6617, 749, 6821, 671, 1767, 4801, 801, 117, 5074, 5381, 1377, 809, 2128, 2552, 5023, 3616, 3152, 1726, 3341, 749, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text_str = ['一个汽车', '广东脆皮烧鸭做得美不美 这几点广式烧鸭做法秘诀很重要', '开心麻花大电影定档 《夏洛特烦恼》逗比老清新',\"赢了这一场硬仗,篮网可以安心等欧文回来了\"]\n",
    "processor(text=text_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pixel_values': [array([[[ 0.14932826,  0.14932826,  0.14932826, ..., -0.55139625,\n",
       "         -0.5659947 , -0.5805931 ],\n",
       "        [ 0.14932826,  0.14932826,  0.14932826, ..., -0.55139625,\n",
       "         -0.5659947 , -0.5805931 ],\n",
       "        [ 0.14932826,  0.14932826,  0.14932826, ..., -0.55139625,\n",
       "         -0.55139625, -0.5659947 ],\n",
       "        ...,\n",
       "        [-0.20103405, -0.12804192,  0.07633615, ..., -1.0331444 ,\n",
       "         -0.7995695 , -0.6973805 ],\n",
       "        [-0.24482933, -0.47840413, -0.5805931 , ..., -0.9893491 ,\n",
       "         -0.94555384, -0.78497106],\n",
       "        [-0.15723877, -0.43460885, -0.60978997, ..., -0.7703726 ,\n",
       "         -0.8579632 , -0.78497106]],\n",
       "\n",
       "       [[ 1.0843711 ,  1.0543556 ,  1.0393478 , ...,  0.55909926,\n",
       "          0.54409146,  0.5290837 ],\n",
       "        [ 1.0543556 ,  1.0543556 ,  1.0393478 , ...,  0.55909926,\n",
       "          0.54409146,  0.5290837 ],\n",
       "        [ 1.0543556 ,  1.0393478 ,  1.0393478 , ...,  0.55909926,\n",
       "          0.55909926,  0.54409146],\n",
       "        ...,\n",
       "        [-0.10124261, -0.04121154,  0.15388943, ..., -0.82161546,\n",
       "         -0.53646785, -0.40139794],\n",
       "        [-0.13125814, -0.38639018, -0.49144456, ..., -0.7315688 ,\n",
       "         -0.65652996, -0.461429  ],\n",
       "        [ 0.01881953, -0.2813358 , -0.4764368 , ..., -0.461429  ,\n",
       "         -0.53646785, -0.4164057 ]],\n",
       "\n",
       "       [[ 1.6055346 ,  1.6624149 ,  1.7477353 , ...,  1.6339747 ,\n",
       "          1.6197547 ,  1.6055346 ],\n",
       "        [ 1.6624149 ,  1.705075  ,  1.7619553 , ...,  1.6339747 ,\n",
       "          1.6197547 ,  1.6055346 ],\n",
       "        [ 1.7477353 ,  1.7619553 ,  1.8046155 , ...,  1.6339747 ,\n",
       "          1.6339747 ,  1.6197547 ],\n",
       "        ...,\n",
       "        [-0.49903518, -0.3710546 , -0.10087335, ..., -0.21463387,\n",
       "          0.0413273 ,  0.16930789],\n",
       "        [-0.48481512, -0.6412359 , -0.68389606, ..., -0.14353354,\n",
       "         -0.08665328,  0.09820756],\n",
       "        [-0.32839438, -0.5274753 , -0.6412359 , ...,  0.09820756,\n",
       "          0.01288717,  0.1266477 ]]], dtype=float32)]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image_str = \"bigdata/image_data/test-31115.jpg\"\n",
    "image_input = Image.open(image_str)\n",
    "processor(images=image_input)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7784adb47ed848a89f99c6cf5da51697",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1346 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22ef6a2c8f63478f985017e8149186e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['image_path', 'text', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 1345373\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['image_path', 'text', 'input_ids', 'attention_mask'],\n",
       "        num_rows: 270\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tokenizer_text(examples) :\n",
    "    res = processor(text=examples['text'],max_length=64, padding=\"max_length\", return_tensors=\"pt\", truncation=True)\n",
    "    examples['input_ids'] = res['input_ids']\n",
    "    examples['attention_mask'] = res['attention_mask']\n",
    "    return examples\n",
    "\n",
    "def transform_images(examples):\n",
    "    images = [Image.open(i) for i in examples['image_path']]\n",
    "    images = [processor(images=i, return_tensors=\"pt\")['pixel_values'] for i in images]\n",
    "    examples['pixel_values'] = images\n",
    "    return examples\n",
    "\n",
    "\n",
    "dataset = Dataset.from_pandas(df=pd.read_csv(\"bigdata/clean_train_test/train.csv\"))\n",
    "dataset = dataset.train_test_split(test_size=0.0002)\n",
    "dataset = dataset.map(\n",
    "    function=tokenizer_text,\n",
    "    batched=True\n",
    ")\n",
    "# dataset = dataset.map(\n",
    "#     function=transform_images,\n",
    "#     batched=True\n",
    "# )\n",
    "\n",
    "dataset.set_transform(transform=transform_images)\n",
    "\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 224, 224])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dataset['train'][2]['input_ids']\n",
    "dataset['train'][3]['pixel_values'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:07<00:00, 28.25it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{torch.Size([64])}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = set(torch.tensor(dataset['train'][i]['input_ids']).shape for i in tqdm(range(200)))\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [00:02<00:00, 77.88it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{torch.Size([64])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = set(torch.tensor(dataset['train'][i]['attention_mask']).shape for i in tqdm(range(200)))\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 64])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack([torch.tensor(dataset['train'][0]['input_ids']),\n",
    "             torch.tensor(dataset['train'][0]['input_ids'])]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 224, 224])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset['train'][0]['pixel_values'].squeeze(0).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 224, 224])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack([dataset['train'][0]['pixel_values'].squeeze(0),\n",
    "             dataset['train'][0]['pixel_values'].squeeze(0),]).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for i in tqdm(range(dataset['train'].__len__())):\n",
    "#     try:\n",
    "#         image = dataset['train'][i]['image_path']\n",
    "#         image = Image.open(image)\n",
    "#         processor(images=image, return_tensors=\"pt\")\n",
    "#     except Exception as e:\n",
    "#         print(dataset['train'][i]['image_path'])\n",
    "#         break\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset['train'][5374]#['image_path']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pixel_values': tensor([[[[ 0.2515,  0.4705,  0.4559,  ...,  0.0179,  0.0033,  0.0033],\n",
       "          [ 0.3683,  0.5873,  0.6457,  ...,  0.0033, -0.0113, -0.0113],\n",
       "          [ 0.3975,  0.4705,  0.6019,  ...,  0.0179,  0.0617,  0.0763],\n",
       "          ...,\n",
       "          [ 0.8792,  0.7041,  0.4121,  ...,  1.1858,  1.2296,  1.1712],\n",
       "          [ 0.5581,  0.4413,  0.2953,  ...,  1.1712,  1.1858,  1.2004],\n",
       "          [-0.0988, -0.1280, -0.1572,  ...,  1.1128,  1.1566,  1.1712]],\n",
       "\n",
       "         [[ 0.3340,  0.5591,  0.5441,  ...,  0.0488,  0.0338,  0.0338],\n",
       "          [ 0.4540,  0.6792,  0.7392,  ...,  0.0338,  0.0188,  0.0188],\n",
       "          [ 0.4841,  0.5591,  0.6942,  ...,  0.0488,  0.0939,  0.1089],\n",
       "          ...,\n",
       "          [ 1.1594,  0.9793,  0.6792,  ...,  1.2344,  1.2795,  1.2194],\n",
       "          [ 0.7842,  0.6642,  0.5141,  ...,  1.2194,  1.2344,  1.2495],\n",
       "          [ 0.0789,  0.0488,  0.0338,  ...,  1.1594,  1.2044,  1.2194]],\n",
       "\n",
       "         [[ 0.5675,  0.7808,  0.7666,  ...,  0.1835,  0.1693,  0.1693],\n",
       "          [ 0.6812,  0.8945,  0.9514,  ...,  0.1693,  0.1551,  0.1551],\n",
       "          [ 0.7097,  0.7808,  0.9088,  ...,  0.1835,  0.2262,  0.2404],\n",
       "          ...,\n",
       "          [ 1.5344,  1.3638,  1.0794,  ...,  1.4349,  1.4776,  1.4207],\n",
       "          [ 1.1932,  1.0794,  0.9372,  ...,  1.4207,  1.4349,  1.4491],\n",
       "          [ 0.5248,  0.4964,  0.4821,  ...,  1.3638,  1.4065,  1.4207]]]])}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n",
    "image = dataset['train'][5374]['image_path']\n",
    "image = Image.open(image)\n",
    "processor(images=image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTorch: setting up devices\n",
      "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
      "Using cuda_amp half precision backend\n",
      "***** Running training *****\n",
      "  Num examples = 1345373\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 28\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 56\n",
      "  Gradient Accumulation steps = 2\n",
      "  Total optimization steps = 24025\n",
      "  Number of trainable parameters = 406233089\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "535c6c4fb6b240318dd77253bae148cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/24025 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (98988480 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3329, 'learning_rate': 0.0004978699183320599, 'epoch': 0.04}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97829a396eef48c999a2c6362e87a17c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-1000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-1000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 9.1508, 'eval_samples_per_second': 29.506, 'eval_steps_per_second': 1.093, 'epoch': 0.04}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-1000\\pytorch_model.bin\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.0004915075250885755, 'epoch': 0.08}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8be40c91f0204fd7adc23bdb49e7bea4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-2000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-2000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.4482, 'eval_samples_per_second': 78.302, 'eval_steps_per_second': 2.9, 'epoch': 0.08}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-2000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-20] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (138921744 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.000481033946311067, 'epoch': 0.12}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "706f6070730640beade23c641cdd53b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-3000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-3000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.3923, 'eval_samples_per_second': 79.592, 'eval_steps_per_second': 2.948, 'epoch': 0.12}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-3000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-40] due to args.save_total_limit\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3322, 'learning_rate': 0.0004666070773952711, 'epoch': 0.17}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "033a9949932e4aef9e10a9db7e10c482",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-4000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-4000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.3887, 'eval_samples_per_second': 79.676, 'eval_steps_per_second': 2.951, 'epoch': 0.17}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-4000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-60] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.00044850157434820096, 'epoch': 0.21}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d33909f7bfb4ca9a73631b4a7a205cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-5000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-5000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.3772, 'eval_samples_per_second': 79.949, 'eval_steps_per_second': 2.961, 'epoch': 0.21}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-5000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-1000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.000426990388338636, 'epoch': 0.25}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "25c610e247d64758aaffbc9a7e133334",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-6000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-6000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.4124, 'eval_samples_per_second': 79.122, 'eval_steps_per_second': 2.93, 'epoch': 0.25}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-6000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-2000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.0004024830452650919, 'epoch': 0.29}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c206738fd00496284f115a763530e49",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-7000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-7000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3026041984558105, 'eval_runtime': 3.7669, 'eval_samples_per_second': 71.676, 'eval_steps_per_second': 2.655, 'epoch': 0.29}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-7000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-3000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.4755, 'learning_rate': 0.00037537729147642094, 'epoch': 0.33}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7dfb4bd92c404859aaacdd1f742c81ab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-8000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-8000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 27.78020477294922, 'eval_runtime': 3.3844, 'eval_samples_per_second': 79.777, 'eval_steps_per_second': 2.955, 'epoch': 0.33}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-8000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-4000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3596, 'learning_rate': 0.00034610484869133153, 'epoch': 0.37}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fcbde947a87a4a55b8052f97953a8c5c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-9000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-9000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.301671266555786, 'eval_runtime': 3.3943, 'eval_samples_per_second': 79.546, 'eval_steps_per_second': 2.946, 'epoch': 0.37}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-9000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-5000] due to args.save_total_limit\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 270\n",
      "  Batch size = 28\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 3.3321, 'learning_rate': 0.00031519144087211273, 'epoch': 0.42}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "305bc3ed12f74f5cafc8d1899f764421",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to clip_chinese_02\\checkpoint-10000\n",
      "Configuration saved in clip_chinese_02\\checkpoint-10000\\config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 3.3030922412872314, 'eval_runtime': 3.3937, 'eval_samples_per_second': 79.56, 'eval_steps_per_second': 2.947, 'epoch': 0.42}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in clip_chinese_02\\checkpoint-10000\\pytorch_model.bin\n",
      "Deleting older checkpoint [clip_chinese_02\\checkpoint-6000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:979: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "def collate_fn(examples):\n",
    "    pixel_values = torch.stack([i['pixel_values'].squeeze(0) for i in examples])\n",
    "    input_ids = torch.tensor([example[\"input_ids\"] for example in examples], dtype=torch.long)#torch.stack([torch.tensor(i, dtype=torch.long) for i in examples['input_ids']])\n",
    "    attention_mask = torch.tensor([example[\"attention_mask\"] for example in examples], dtype=torch.long)#torch.stack([torch.tensor(i) for i in examples['attention_mask']])\n",
    "    return {\n",
    "        \"pixel_values\": pixel_values,\n",
    "        \"input_ids\": input_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"return_loss\": True,\n",
    "    }\n",
    "\n",
    "\n",
    "train_argument = TrainingArguments(\n",
    "    output_dir=\"clip_chinese_02\",\n",
    "    per_device_train_batch_size=28,\n",
    "    per_device_eval_batch_size=28,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=1000,\n",
    "    logging_steps=1000,\n",
    "    gradient_accumulation_steps=2,\n",
    "    num_train_epochs=1,\n",
    "    weight_decay=0.1,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    learning_rate=5e-4,\n",
    "    save_steps=1000,\n",
    "    fp16=True,\n",
    "    remove_unused_columns=False,\n",
    "    save_total_limit=4\n",
    "\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=train_argument,\n",
    "    train_dataset=dataset['train'],\n",
    "    eval_dataset=dataset['test'],\n",
    "    data_collator=collate_fn,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset['train'][5374]#['image_path']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "110bc624a448454d574a0cd6cc76359fd86f75739e493913b3d71c2e04f2ffdb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
